{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd /data/ajay/contracode\n",
    "%pwd\n",
    "# Use %pip (not !pip) so packages are installed into the environment of the running kernel.\n",
    "%pip uninstall -y pandas\n",
    "# `sklearn` on PyPI is only a deprecated alias; the real distribution name is `scikit-learn`.\n",
    "%pip install pandas scikit-learn scipy matplotlib seaborn numpy plotnine altair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/paras/miniconda3/lib/python3.8/site-packages/sklearn/utils/deprecation.py:143: FutureWarning: The sklearn.manifold.t_sne module is  deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.manifold. Anything that cannot be imported from sklearn.manifold is now part of the private API.\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "import altair as alt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotnine as p9\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy import linalg\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.metrics import pairwise_distances\n",
    "\n",
    "# The public-looking `sklearn.manifold.t_sne` module is deprecated since scikit-learn 0.22\n",
    "# and scheduled for removal in 0.24; the helper lives in the private `_t_sne` module.\n",
    "# Fall back to the old path so releases older than 0.22 keep working.\n",
    "try:\n",
    "    from sklearn.manifold._t_sne import _joint_probabilities\n",
    "except ImportError:\n",
    "    from sklearn.manifold.t_sne import _joint_probabilities\n",
    "\n",
    "# Default figure size (inches) and a 10-colour palette for seaborn plots.\n",
    "sns.set(rc={'figure.figsize':(11.7,8.27)})\n",
    "palette = sns.color_palette(\"bright\", 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "moco_path = \"data/tsne/tsne_out_embedded_grouped_hidden.pickle\"\n",
    "!ls -lah {moco_path}\n",
    "# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.\n",
    "with open(moco_path, 'rb') as f:\n",
    "    # Deserialize straight from the file handle instead of first copying the whole\n",
    "    # file into a bytes object that would then linger in the kernel namespace.\n",
    "    in_tsne_embeddings = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (name, integer) pairs, listed in descending order of the integer.\n",
    "# Presumably the integer is how often the name occurs in the dataset -- TODO confirm.\n",
    "matches = [\n",
    "    ('init', 282),\n",
    "    ('parse', 242),\n",
    "    ('get', 186),\n",
    "    ('create', 181),\n",
    "    ('validate', 145),\n",
    "    ('run', 144),\n",
    "    ('update', 128),\n",
    "    ('extend', 127),\n",
    "    ('merge', 111),\n",
    "    ('set', 109),\n",
    "    ('render', 105),\n",
    "    ('transform', 103),\n",
    "    ('resolve', 100),\n",
    "    ('main', 99),\n",
    "    ('request', 98),\n",
    "    ('log', 98),\n",
    "    ('add', 95),\n",
    "    ('load', 93),\n",
    "    ('format', 90),\n",
    "    ('client', 90),\n",
    "    ('compile', 87),\n",
    "    ('start', 87),\n",
    "    ('find', 84),\n",
    "    ('normalize', 83),\n",
    "    ('clone', 81),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_tsne_dict(data, include=None, isolate_color=None, n_iter=10000, perplexity=50, mean=True,\n",
    "                      learning_rate=25., random_state=None):\n",
    "    \"\"\"Embed selected items with t-SNE and return the 2-D points as a DataFrame.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : iterable of mapping\n",
    "        Every item is read through ``item['func_name']`` and ``item['embedding']``.\n",
    "        The embedding must provide ``.mean(0)`` and ``.flatten()``.\n",
    "    include : container, optional\n",
    "        When given, only items whose ``func_name`` is in it are embedded.\n",
    "    isolate_color : container, optional\n",
    "        When given, items whose ``func_name`` is not in it are labelled ``'Other'``.\n",
    "    n_iter, perplexity, learning_rate, random_state\n",
    "        Forwarded unchanged to ``TSNE``. Pass an int as ``random_state`` to make\n",
    "        the layout reproducible between runs.\n",
    "    mean : bool\n",
    "        If true each embedding is reduced with ``.mean(0)``, otherwise it is flattened.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        One row per embedded item with columns ``x``, ``y`` and ``label`` (a str).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no item is left after applying the ``include`` filter.\n",
    "    \"\"\"\n",
    "    selected = [item for item in data if include is None or item['func_name'] in include]\n",
    "    # Fail early with a readable message instead of an opaque error from inside t-SNE.\n",
    "    if not selected:\n",
    "        raise ValueError('no items left to embed after applying the `include` filter')\n",
    "\n",
    "    labels = [\n",
    "        item['func_name'] if isolate_color is None or item['func_name'] in isolate_color else 'Other'\n",
    "        for item in selected\n",
    "    ]\n",
    "    if mean:\n",
    "        features = [item['embedding'].mean(0) for item in selected]\n",
    "    else:\n",
    "        features = [item['embedding'].flatten() for item in selected]\n",
    "\n",
    "    tsne = TSNE(n_iter=n_iter, perplexity=perplexity, n_jobs=-1,\n",
    "                learning_rate=learning_rate, random_state=random_state)\n",
    "    coords = tsne.fit_transform(features)\n",
    "\n",
    "    return pd.DataFrame(\n",
    "        [dict(x=x, y=y, label=str(label)) for (x, y), label in zip(coords, labels)]\n",
    "    )\n",
    "\n",
    "# Function names that keep their own colour; every other included name is drawn as 'Other'.\n",
    "not_grey_list = ['validate', 'normalize', 'compile']\n",
    "# Restrict the embedding to the names listed in `matches`.\n",
    "include_names = [name for name, _ in matches]\n",
    "df = compute_tsne_dict(\n",
    "    in_tsne_embeddings,\n",
    "    include=include_names,\n",
    "    isolate_color=not_grey_list,\n",
    "    n_iter=4000,\n",
    "    perplexity=32,\n",
    "    mean=True,\n",
    ")\n",
    "\n",
    "# Interactive scatter plot; hovering a point shows its label.\n",
    "chart = alt.Chart(df).mark_circle(size=60).encode(x='x', y='y', color='label', tooltip=['label'])\n",
    "chart.interactive()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Flat embedding file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "moco_path = \"data/tsne/moco_embed_tsne_embeddings.pickle\"\n",
    "!ls -lah {moco_path}\n",
    "# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.\n",
    "with open(moco_path, 'rb') as f:\n",
    "    # Deserialize straight from the file handle instead of first copying the whole\n",
    "    # file into a bytes object that would then linger in the kernel namespace.\n",
    "    in_tsne_embeddings_file = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per group: drop missing (None) records and keep element [1] of each remaining record as an array.\n",
    "filtered_embeddings = [\n",
    "    [np.asarray(record[1]) for record in group if record is not None]\n",
    "    for group in in_tsne_embeddings_file\n",
    "]\n",
    "# Ascending sort by group size, so the tail slice holds the (up to) 128 largest groups.\n",
    "sorted_embeddings = sorted(filtered_embeddings, key=len)[-128:]\n",
    "# NOTE: despite the name, this is the same list as above -- up to 128 groups, not 20.\n",
    "top20_samples = sorted_embeddings\n",
    "\n",
    "def compute_tsne(data, exclude=(), n_iter=10000, perplexity=50, mean=True,\n",
    "                 early_exaggeration=5., random_state=None):\n",
    "    \"\"\"Run t-SNE over grouped samples and return the 2-D points plus a scatter plot.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : iterable of sized sequences\n",
    "        Groups of samples. Each sample must provide ``.mean(0)`` and ``.flatten()``;\n",
    "        a sample is labelled with the position of its group in ``data``.\n",
    "    exclude : sequence, optional\n",
    "        Accepted for backward compatibility; it is currently not used.\n",
    "    n_iter, perplexity, early_exaggeration, random_state\n",
    "        Forwarded unchanged to ``TSNE``. Pass an int as ``random_state`` to make\n",
    "        the layout reproducible between runs.\n",
    "    mean : bool\n",
    "        If true each sample is reduced with ``.mean(0).mean(0)``, otherwise it is flattened.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (pandas.DataFrame, plotnine.ggplot)\n",
    "        The frame has one row per sample with columns ``x``, ``y`` and ``label``\n",
    "        (the group index as str); the plot is a scatter of those points coloured by label.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``data`` contains no samples at all.\n",
    "    \"\"\"\n",
    "    labels = []\n",
    "    samples = []\n",
    "    for group_idx, group in enumerate(data):\n",
    "        labels.extend([group_idx] * len(group))\n",
    "        samples.extend(group)\n",
    "    # Fail early with a readable message instead of an opaque error from inside t-SNE.\n",
    "    if not samples:\n",
    "        raise ValueError('no samples to embed: `data` contains no non-empty group')\n",
    "\n",
    "    if mean:\n",
    "        features = [sample.mean(0).mean(0) for sample in samples]\n",
    "    else:\n",
    "        features = [sample.flatten() for sample in samples]\n",
    "\n",
    "    tsne = TSNE(n_iter=n_iter, perplexity=perplexity, n_jobs=-1,\n",
    "                early_exaggeration=early_exaggeration, random_state=random_state)\n",
    "    coords = tsne.fit_transform(features)\n",
    "\n",
    "    df = pd.DataFrame(\n",
    "        [dict(x=x, y=y, label=str(label)) for (x, y), label in zip(coords, labels)]\n",
    "    )\n",
    "    # Keyword arguments keep the call independent of ggplot's positional parameter order.\n",
    "    plot = (\n",
    "        p9.ggplot(data=df, mapping=p9.aes('x', 'y'))\n",
    "        + p9.geom_point(p9.aes(color='label'), alpha=0.8)\n",
    "        + p9.theme_classic()\n",
    "    )\n",
    "    return df, plot\n",
    "\n",
    "# Run t-SNE once per perplexity value on the flattened samples (mean=False).\n",
    "for perp in [64]:\n",
    "    # compute_tsne returns (df, plot); only the plot (index 1) is kept. `%time` reports the run time.\n",
    "    %time plot = compute_tsne(top20_samples, n_iter=1000, perplexity=perp, mean=False)[1]\n",
    "    print(perp, plot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
